{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 348,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import torch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [],
   "source": [
    "M = np.arange(1,26)\n",
    "np.random.shuffle(M)\n",
    "M = torch.tensor(M.reshape(5,5))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[ 8, 11, 17, 13, 14],\n",
       "        [ 7, 10, 19,  5, 12],\n",
       "        [15,  1, 23,  9, 20],\n",
       "        [ 4,  2, 25, 16,  3],\n",
       "        [21, 18, 22,  6, 24]])"
      ]
     },
     "execution_count": 27,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "M"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "# what is contiguous doing anyway?\n",
    "# some error brought up in torch - don't remember now exactly the cause\n",
    "N  = M.contiguous()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(5)"
      ]
     },
     "execution_count": 36,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tau = 0.2\n",
    "kth = int(tau * np.prod(M.shape))\n",
    "# kth value aligs, no need to sort\n",
    "theta, _ = torch.kthvalue(M.view(-1), kth)\n",
    "theta"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[ True,  True,  True,  True,  True],\n",
       "        [ True,  True,  True, False,  True],\n",
       "        [ True, False,  True,  True,  True],\n",
       "        [False, False,  True,  True, False],\n",
       "        [ True,  True,  True,  True,  True]])"
      ]
     },
     "execution_count": 37,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "M > theta"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The ones that should stay - are the ones with highest correlation. But while the weight keep mask makes sense, keep the highest inweight "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([23, 25, 21, 22, 24])"
      ]
     },
     "execution_count": 30,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "M[M>theta]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(torch.Size([10, 10]), torch.Size([10, 10]))"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# I'm only keeping the top 20%\n",
    "# what is the effect when I add the weight mask?"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[ True,  True,  True,  True,  True],\n",
       "        [ True,  True,  True, False,  True],\n",
       "        [ True, False,  True,  True,  True],\n",
       "        [False, False,  True,  True, False],\n",
       "        [ True,  True,  True, False,  True]])"
      ]
     },
     "execution_count": 32,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# calculate weight mask\n",
    "weight = M\n",
    "zeta = 0.3\n",
    "weight_pos = weight[weight > 0]\n",
    "pos_threshold, _ = torch.kthvalue(\n",
    "    weight_pos, max(int(zeta * len(weight_pos)), 1)\n",
    ")\n",
    "weight_keep_mask = (weight >= pos_threshold)\n",
    "weight_keep_mask"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([ 8, 11, 17, 13, 14,  7, 10, 19, 12, 15, 23,  9, 20, 25, 16, 21, 18, 22,\n",
       "        24])"
      ]
     },
     "execution_count": 34,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "M[weight_keep_mask]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Testing hebbian only"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 183,
   "metadata": {},
   "outputs": [],
   "source": [
    "M = np.arange(1,26)\n",
    "np.random.shuffle(M)\n",
    "M = torch.tensor(M.reshape(5,5))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 184,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor([[0.3201, 0.8318, 0.3382, 0.9734, 0.0985],\n",
       "         [0.0401, 0.8620, 0.0845, 0.3778, 0.3996],\n",
       "         [0.4954, 0.0092, 0.6713, 0.8594, 0.9487],\n",
       "         [0.8101, 0.0922, 0.2033, 0.7185, 0.4588],\n",
       "         [0.3897, 0.6865, 0.5072, 0.9749, 0.0597]]),\n",
       " tensor([[19, 21, 12,  0,  0],\n",
       "         [ 0,  0,  0,  0,  0],\n",
       "         [10, 25,  8,  0,  0],\n",
       "         [ 2, 11,  7,  0,  0],\n",
       "         [14, 18,  6,  0,  0]]),\n",
       " 12)"
      ]
     },
     "execution_count": 184,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "weight = M\n",
    "corr = torch.rand((5,5))\n",
    "weight[1, :] = 0\n",
    "weight[:, 3] = 0\n",
    "weight[:, 4] = 0\n",
    "num_params = torch.sum(weight != 0).item()\n",
    "corr, weight, num_params"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 350,
   "metadata": {},
   "outputs": [],
   "source": [
    "def prune(weight, num_params, corr, tau, idx=0, hebbian_grow=True):\n",
    "    with torch.no_grad():\n",
    "\n",
    "        # print(\"corr dimension\", corr.shape)\n",
    "        # print(\"weight dimension\", weight.shape)\n",
    "\n",
    "        # transpose to fit the weights, and eliminate zero weight\n",
    "        num_synapses = np.prod(corr.shape)\n",
    "        # corr = corr.t()\n",
    "        active_synapses = (weight != 0)\n",
    "        nonactive_synapses = (weight == 0)\n",
    "        total_active = torch.sum(active_synapses).item()\n",
    "        total_nonactive = torch.sum(nonactive_synapses).item()\n",
    "\n",
    "        corr_active = corr[active_synapses]\n",
    "        # decide which weights to remove based on correlation\n",
    "        kth = int(tau * total_active)\n",
    "        print(\"total active: \", total_active)\n",
    "        print(\"kth: \", kth)\n",
    "        # if kth = 0, keep all the synapses\n",
    "        if kth == 0:\n",
    "            hebbian_keep_mask = active_synapses\n",
    "        # else if kth greater than shape, remove all synapses\n",
    "        elif kth >= num_synapses:\n",
    "            hebbian_keep_mask = torch.zeros(corr.shape)\n",
    "        # if no edge cases\n",
    "        else:\n",
    "            keep_threshold, _ = torch.kthvalue(corr_active, kth)\n",
    "            print(keep_threshold)\n",
    "            # keep mask are ones above threshold and currently active\n",
    "            hebbian_keep_mask = (corr > keep_threshold) & active_synapses\n",
    "            \n",
    "        # keep_mask = weight_keep_mask & hebbian_keep_mask\n",
    "        keep_mask = hebbian_keep_mask\n",
    "        num_add = max(num_params - torch.sum(keep_mask).item(), 0)  \n",
    "\n",
    "        # added option to have hebbian grow or not\n",
    "        if hebbian_grow:\n",
    "            # get threshold\n",
    "            kth = total_nonactive - num_add\n",
    "            corr_nonactive = corr[nonactive_synapses]\n",
    "            add_threshold, _ = torch.kthvalue(corr_nonactive, kth)\n",
    "            # calculate mask, only for currently nonactive\n",
    "            add_mask = (corr > add_threshold) & nonactive_synapses\n",
    "        else:\n",
    "            # probability of adding is 1 or lower\n",
    "            p_add = num_add / max(total_nonactive, num_add)\n",
    "            random_sample = torch.rand(num_synapses) < p_add\n",
    "            add_mask = random_sample & nonactive_synapses\n",
    "\n",
    "        # calculate the new mask\n",
    "        new_mask = keep_mask | add_mask\n",
    "\n",
    "    # track added connections\n",
    "    return new_mask, keep_mask, add_mask\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 355,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "total active:  12\n",
      "kth:  3\n",
      "tensor(0.2033)\n"
     ]
    }
   ],
   "source": [
    "new_mask, keep_mask, add_mask = prune(weight, num_params, corr, tau=0.25)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 352,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[ 19,  21, -12,   0,   0],\n",
       "        [  0,   0,   0,   0,   0],\n",
       "        [ 10,  25,   8,   0,   0],\n",
       "        [  2, -11,   7,   0,   0],\n",
       "        [ 14,  18,  -6,   0,   0]])"
      ]
     },
     "execution_count": 352,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "weight"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 353,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[0.3201, 0.8318, 0.3382, 0.9734, 0.0985],\n",
       "        [0.0401, 0.8620, 0.0845, 0.3778, 0.3996],\n",
       "        [0.4954, 0.0092, 0.6713, 0.8594, 0.9487],\n",
       "        [0.8101, 0.0922, 0.2033, 0.7185, 0.4588],\n",
       "        [0.3897, 0.6865, 0.5072, 0.9749, 0.0597]])"
      ]
     },
     "execution_count": 353,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "corr"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 157,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor([[False,  True,  True, False, False],\n",
       "         [False, False, False, False, False],\n",
       "         [ True,  True,  True, False, False],\n",
       "         [ True, False,  True, False, False],\n",
       "         [ True, False,  True, False, False]]), tensor(9))"
      ]
     },
     "execution_count": 157,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# no prunning happened\n",
    "keep_mask, torch.sum(keep_mask)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 158,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor([[False, False, False, False, False],\n",
       "         [False, False, False,  True, False],\n",
       "         [False, False, False, False, False],\n",
       "         [False, False, False,  True, False],\n",
       "         [False, False, False, False,  True]]), tensor(3))"
      ]
     },
     "execution_count": 158,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# adding almost all of the new items\n",
    "add_mask, torch.sum(add_mask)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 159,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor([[False,  True,  True, False, False],\n",
       "         [False, False, False,  True, False],\n",
       "         [ True,  True,  True, False, False],\n",
       "         [ True, False,  True,  True, False],\n",
       "         [ True, False,  True, False,  True]]), tensor(12))"
      ]
     },
     "execution_count": 159,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "new_mask, torch.sum(new_mask)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Testing hebbian + weights"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 171,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[21,  9, 12,  7,  6],\n",
       "        [14, 10, 17,  2,  8],\n",
       "        [ 5, 25, 13, 11,  1],\n",
       "        [ 4,  3, 24, 15, 16],\n",
       "        [20, 18, 23, 22, 19]])"
      ]
     },
     "execution_count": 171,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "N = np.arange(1,26)\n",
    "np.random.shuffle(N)\n",
    "N = torch.tensor(N.reshape(5,5))\n",
    "N"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 178,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(torch.return_types.kthvalue(\n",
       " values=tensor(5),\n",
       " indices=tensor(10)), torch.return_types.kthvalue(\n",
       " values=tensor(25),\n",
       " indices=tensor(11)))"
      ]
     },
     "execution_count": 178,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.kthvalue(N.view(-1), 5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 179,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.return_types.kthvalue(\n",
       "values=tensor(25),\n",
       "indices=tensor(11))"
      ]
     },
     "execution_count": 179,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.kthvalue(N.view(-1), 25)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 185,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor([[0.9703, 0.9620, 0.5798, 0.8359, 0.0570],\n",
       "         [0.6940, 0.1080, 0.2981, 0.9239, 0.3559],\n",
       "         [0.6024, 0.7168, 0.1934, 0.4220, 0.1958],\n",
       "         [0.7452, 0.7896, 0.7346, 0.5306, 0.1022],\n",
       "         [0.3658, 0.7152, 0.4189, 0.8674, 0.2408]]),\n",
       " tensor([[19, 21, 12,  0,  0],\n",
       "         [ 0,  0,  0,  0,  0],\n",
       "         [10, 25,  8,  0,  0],\n",
       "         [ 2, 11,  7,  0,  0],\n",
       "         [14, 18,  6,  0,  0]]),\n",
       " 12)"
      ]
     },
     "execution_count": 185,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "weight = M\n",
    "corr = torch.rand((5,5))\n",
    "weight[1, :] = 0\n",
    "weight[:, 3] = 0\n",
    "weight[:, 4] = 0\n",
    "num_params = torch.sum(weight != 0).item()\n",
    "corr, weight, num_params"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 360,
   "metadata": {},
   "outputs": [],
   "source": [
    "def prune(weight, num_params, corr, tau, zeta, idx=0, hebbian_grow=True):\n",
    "    with torch.no_grad():\n",
    "\n",
    "        # print(\"corr dimension\", corr.shape)\n",
    "        # print(\"weight dimension\", weight.shape)\n",
    "\n",
    "        # transpose to fit the weights, and eliminate zero weight\n",
    "        num_synapses = np.prod(weight.shape)\n",
    "        active_synapses = (weight != 0)\n",
    "        nonactive_synapses = (weight == 0)\n",
    "        total_active = torch.sum(active_synapses).item()\n",
    "        total_nonactive = torch.sum(nonactive_synapses).item()\n",
    "\n",
    "        # ----------- HEBBIAN PRUNING ----------------\n",
    "        \n",
    "        if tau is not None:\n",
    "            # corr = corr.t()\n",
    "            corr_active = corr[active_synapses]\n",
    "            # decide which weights to remove based on correlation\n",
    "            kth = int(tau * total_active)\n",
    "            print(\"total active: \", total_active)\n",
    "            print(\"kth: \", kth)\n",
    "            # if kth = 0, keep all the synapses\n",
    "            if kth == 0:\n",
    "                hebbian_keep_mask = active_synapses\n",
    "            # else if kth greater than shape, remove all synapses\n",
    "            elif kth >= num_synapses:\n",
    "                hebbian_keep_mask = torch.zeros(weight.shape)\n",
    "            # if no edge cases\n",
    "            else:\n",
    "                keep_threshold, _ = torch.kthvalue(corr_active, kth)\n",
    "                print(keep_threshold)\n",
    "                # keep mask are ones above threshold and currently active\n",
    "                hebbian_keep_mask = (corr > keep_threshold) & active_synapses\n",
    "\n",
    "        # ----------- WEIGHT PRUNING ----------------\n",
    "                        \n",
    "        if zeta is not None:\n",
    "            # calculate the positive\n",
    "            weight_pos = weight[weight > 0]\n",
    "            pos_kth = int(zeta * len(weight_pos))\n",
    "            if pos_kth == 0:\n",
    "                pos_threshold = -1\n",
    "            else:\n",
    "                pos_threshold, _ = torch.kthvalue(weight_pos, pos_kth)\n",
    "            print(pos_kth, pos_threshold)\n",
    "            \n",
    "            # calculate the negative\n",
    "            weight_neg = weight[weight < 0]\n",
    "            neg_kth = int((1-zeta) * len(weight_neg))\n",
    "            if neg_kth == 0:\n",
    "                neg_threshold = 1\n",
    "            else:\n",
    "                neg_threshold, _ = torch.kthvalue(weight_neg, neg_kth)\n",
    "            print(neg_kth, neg_threshold)                \n",
    "                \n",
    "            partial_weight_mask = (weight > pos_threshold) | (weight <= neg_threshold)\n",
    "            weight_mask = partial_weight_mask & active_synapses\n",
    "\n",
    "        # ----------- COMBINE HEBBIAN AND WEIGHT ----------------            \n",
    "            \n",
    "        # join both masks\n",
    "        if tau and zeta:\n",
    "            keep_mask = hebbian_keep_mask | weight_mask\n",
    "        elif tau:\n",
    "            keep_mask = hebbian_keep_mask\n",
    "        elif zeta:\n",
    "            keep_mask = weight_mask\n",
    "        else:\n",
    "            keep_mask = active_synapses\n",
    "\n",
    "        # ----------- GROWTH ----------------            \n",
    "\n",
    "        # calculate number of params removed to be readded\n",
    "        num_add = max(num_params - torch.sum(keep_mask).item(), 0)\n",
    "        print(num_add)\n",
    "        # added option to have hebbian grow or not\n",
    "        if hebbian_grow:\n",
    "            # get threshold\n",
    "            kth = total_nonactive - num_add\n",
    "            corr_nonactive = corr[nonactive_synapses]\n",
    "            add_threshold, _ = torch.kthvalue(corr_nonactive, kth)\n",
    "            # calculate mask, only for currently nonactive\n",
    "            add_mask = (corr > add_threshold) & nonactive_synapses\n",
    "        else:\n",
    "            # probability of adding is 1 or lower\n",
    "            p_add = num_add / max(total_nonactive, num_add)\n",
    "            print(p_add)\n",
    "            random_sample = torch.rand(num_synapses) < p_add\n",
    "            add_mask = random_sample & nonactive_synapses\n",
    "\n",
    "        # calculate the new mask\n",
    "        new_mask = keep_mask | add_mask\n",
    "\n",
    "    # track added connections\n",
    "    return new_mask, keep_mask, add_mask\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 357,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor([[0.7310, 0.6574, 0.8158, 0.6538, 0.5781],\n",
       "         [0.7679, 0.7877, 0.4430, 0.9151, 0.9181],\n",
       "         [0.0374, 0.2257, 0.6370, 0.9288, 0.2408],\n",
       "         [0.5740, 0.7109, 0.2618, 0.2762, 0.1792],\n",
       "         [0.2557, 0.9166, 0.7356, 0.7074, 0.0236]]),\n",
       " tensor([[19, 21, 12,  0,  0],\n",
       "         [ 0,  0,  0,  0,  0],\n",
       "         [10, 25,  8,  0,  0],\n",
       "         [ 2, 11,  7,  0,  0],\n",
       "         [14, 18,  6,  0,  0]]),\n",
       " 12)"
      ]
     },
     "execution_count": 357,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "weight = M\n",
    "corr = torch.rand((5,5))\n",
    "weight[1, :] = 0\n",
    "weight[:, 3] = 0\n",
    "weight[:, 4] = 0\n",
    "num_params = torch.sum(weight != 0).item()\n",
    "corr, weight, num_params"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 387,
   "metadata": {},
   "outputs": [],
   "source": [
    "# coactivation matrix\n",
    "corr = torch.tensor([ [0.3201, 0.8318, 0.3382, 0.9734, 0.0985],\n",
    "                [0.0401, 0.8620, 0.0845, 0.3778, 0.3996],\n",
    "                [0.4954, 0.0092, 0.6713, 0.8594, 0.9487],\n",
    "                [0.8101, 0.0922, 0.2033, 0.7185, 0.4588],\n",
    "                [0.3897, 0.6865, 0.5072, 0.9749, 0.0597]])\n",
    "# weight matrix\n",
    "weight = torch.tensor([\n",
    "                 [19, 21, -12,  0,  0],\n",
    "                 [ 0,  0,  0,  0,  0],\n",
    "                 [-10, 25, -8,  0,  0],\n",
    "                 [ 2, -11,  7,  0,  0],\n",
    "                 [-14, 18,  -6,  0,  0]])\n",
    "num_params = torch.sum(weight != 0).item()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 388,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "total active:  12\n",
      "kth:  3\n",
      "tensor(0.2033)\n",
      "3 tensor(18)\n",
      "3 tensor(-11)\n",
      "1\n"
     ]
    }
   ],
   "source": [
    "new_mask, keep_mask, add_mask = prune(weight, num_params, corr, tau=0.25, zeta=0.50)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 389,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[ True,  True,  True, False, False],\n",
       "        [False, False, False, False, False],\n",
       "        [ True,  True,  True, False, False],\n",
       "        [ True,  True, False, False, False],\n",
       "        [ True,  True,  True, False, False]])"
      ]
     },
     "execution_count": 389,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "keep_mask"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 386,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[False, False, False,  True, False],\n",
       "        [False, False, False, False, False],\n",
       "        [False, False, False, False, False],\n",
       "        [False, False, False, False, False],\n",
       "        [False, False, False,  True, False]])"
      ]
     },
     "execution_count": 386,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "add_mask"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 379,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[ True,  True,  True, False, False],\n",
       "        [False, False, False, False, False],\n",
       "        [ True,  True,  True, False, False],\n",
       "        [ True,  True, False, False, False],\n",
       "        [ True,  True,  True,  True, False]])"
      ]
     },
     "execution_count": 379,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "new_mask"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Flipping the logic"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 336,
   "metadata": {},
   "outputs": [],
   "source": [
    "def prune(weight, num_params, corr, tau, zeta, idx=0, hebbian_grow=True):\n",
    "    with torch.no_grad():\n",
    "\n",
    "        # print(\"corr dimension\", corr.shape)\n",
    "        # print(\"weight dimension\", weight.shape)\n",
    "\n",
    "        # transpose to fit the weights, and eliminate zero weight\n",
    "        num_synapses = np.prod(weight.shape)\n",
    "        active_synapses = (weight != 0)\n",
    "        nonactive_synapses = (weight == 0)\n",
    "        total_active = torch.sum(active_synapses).item()\n",
    "        total_nonactive = torch.sum(nonactive_synapses).item()\n",
    "\n",
    "        # ----------- HEBBIAN PRUNING ----------------\n",
    "        \n",
    "        if tau is not None:\n",
    "            # corr = corr.t()\n",
    "            corr_active = corr[active_synapses]\n",
    "            # decide which weights to remove based on correlation\n",
    "            kth = int((1-tau) * total_active)\n",
    "            print(\"total active: \", total_active)\n",
    "            print(\"kth: \", kth)\n",
    "            # if kth = 0, keep all the synapses\n",
    "            if kth == 0:\n",
    "                hebbian_keep_mask = torch.zeros(weight.shape).bool()\n",
    "            # else if kth greater than shape, remove all synapses\n",
    "            elif kth >= num_synapses:\n",
    "                hebbian_keep_mask = active_synapses\n",
    "            # if no edge cases\n",
    "            else:\n",
    "                keep_threshold, _ = torch.kthvalue(corr_active, kth)\n",
    "                print(keep_threshold)\n",
    "                # keep mask are ones above threshold and currently active\n",
    "                hebbian_keep_mask = (corr <= keep_threshold) & active_synapses\n",
    "            print(\"hebbian_keep_mask\",  hebbian_keep_mask)\n",
    "\n",
    "        # ----------- WEIGHT PRUNING ----------------\n",
    "                        \n",
    "        if zeta is not None:\n",
    "            \n",
    "            # calculate the positive\n",
    "            weight_pos = weight[weight > 0]\n",
    "            pos_kth = int(zeta * len(weight_pos))\n",
    "            # if no positive weight, threshold can be 0 (select none)\n",
    "            if len(weight_pos) > 0:\n",
    "                # if zeta=0, pos_kth=0, prune nothing\n",
    "                if pos_kth == 0:\n",
    "                    pos_threshold = -1\n",
    "                else:\n",
    "                    pos_threshold, _ = torch.kthvalue(weight_pos, pos_kth)\n",
    "            else:\n",
    "                pos_threshold = 0\n",
    "\n",
    "            # calculate the negative\n",
    "            weight_neg = weight[weight < 0]\n",
    "            neg_kth = int((1-zeta) * len(weight_neg))\n",
    "            # if no negative weight, threshold -1 (select none)\n",
    "            if len(weight_neg) > 0:\n",
    "                # if zeta=1, neg_kth=0, prune all\n",
    "                if neg_kth == 0:\n",
    "                    neg_threshold = torch.min(weight_neg).item() - 1\n",
    "                else:\n",
    "                    neg_threshold, _ = torch.kthvalue(weight_neg, neg_kth)\n",
    "            else:\n",
    "                neg_threshold = -1\n",
    "\n",
    "            partial_weight_mask = (weight > pos_threshold) | (weight <= neg_threshold)\n",
    "            weight_mask = partial_weight_mask & active_synapses\n",
    "            print(\"weight_mask\", weight_mask)\n",
    "\n",
    "        # ----------- COMBINE HEBBIAN AND WEIGHT ----------------            \n",
    "            \n",
    "        # join both masks\n",
    "        if tau and zeta:\n",
    "            keep_mask = hebbian_keep_mask | weight_mask\n",
    "        elif tau:\n",
    "            keep_mask = hebbian_keep_mask\n",
    "        elif zeta:\n",
    "            keep_mask = weight_mask\n",
    "        else:\n",
    "            keep_mask = active_synapses\n",
    "\n",
    "        # ----------- GROWTH ----------------            \n",
    "\n",
    "        # calculate number of params removed to be readded\n",
    "        num_add = max(num_params - torch.sum(keep_mask).item(), 0)\n",
    "        print(num_add)\n",
    "        # added option to have hebbian grow or not\n",
    "        if hebbian_grow:\n",
    "            # get threshold\n",
    "            kth = num_add # should not be non-int\n",
    "            if kth > 0:\n",
    "                corr_nonactive = corr[nonactive_synapses]\n",
    "                add_threshold, _ = torch.kthvalue(corr_nonactive, kth)\n",
    "                # calculate mask, only for currently nonactive\n",
    "                add_mask = (corr <= add_threshold) & nonactive_synapses\n",
    "            # if there is nothing to add, return zeros\n",
    "            else:\n",
    "                add_mask = torch.zeros(weight.shape).bool()\n",
    "        else:\n",
    "            # probability of adding is 1 or lower\n",
    "            p_add = num_add / max(total_nonactive, num_add)\n",
    "            print(p_add)\n",
    "            random_sample = torch.rand(num_synapses) < p_add\n",
    "            add_mask = random_sample & nonactive_synapses\n",
    "\n",
    "        # calculate the new mask\n",
    "        new_mask = keep_mask | add_mask\n",
    "\n",
    "    # track added connections\n",
    "    return new_mask, keep_mask, add_mask\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 337,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor([[0.2943, 0.4472, 0.6258, 0.5547, 0.6841],\n",
       "         [0.7235, 0.9685, 0.5549, 0.5836, 0.1360],\n",
       "         [0.7461, 0.2407, 0.3790, 0.2589, 0.1135],\n",
       "         [0.6667, 0.4287, 0.4017, 0.2251, 0.2324],\n",
       "         [0.2342, 0.6125, 0.4358, 0.9662, 0.5876]]),\n",
       " tensor([[19, 21, 12,  0,  0],\n",
       "         [ 0,  0,  0,  0,  0],\n",
       "         [10, 25,  8,  0,  0],\n",
       "         [ 2, 11,  7,  0,  0],\n",
       "         [14, 18,  6,  0,  0]]),\n",
       " 12)"
      ]
     },
     "execution_count": 337,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "weight = M\n",
    "corr = torch.rand((5,5))\n",
    "weight[1, :] = 0\n",
    "weight[:, 3] = 0\n",
    "weight[:, 4] = 0\n",
    "num_params = torch.sum(weight != 0).item()\n",
    "corr, weight, num_params"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 343,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0\n"
     ]
    }
   ],
   "source": [
    "# new_mask, keep_mask, add_mask = prune(weight, num_params, corr, tau=1, zeta=1)\n",
    "# new_mask, keep_mask, add_mask = prune(weight, num_params, corr, tau=1, zeta=None)\n",
    "# new_mask, keep_mask, add_mask = prune(weight, num_params, corr, tau=None, zeta=1)\n",
    "new_mask, keep_mask, add_mask = prune(weight, num_params, corr, tau=None, zeta=None)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 344,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[ True,  True,  True, False, False],\n",
       "        [False, False, False, False, False],\n",
       "        [ True,  True,  True, False, False],\n",
       "        [ True,  True,  True, False, False],\n",
       "        [ True,  True,  True, False, False]])"
      ]
     },
     "execution_count": 344,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "keep_mask"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 345,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[False, False, False, False, False],\n",
       "        [False, False, False, False, False],\n",
       "        [False, False, False, False, False],\n",
       "        [False, False, False, False, False],\n",
       "        [False, False, False, False, False]])"
      ]
     },
     "execution_count": 345,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "add_mask"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 346,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[ True,  True,  True, False, False],\n",
       "        [False, False, False, False, False],\n",
       "        [ True,  True,  True, False, False],\n",
       "        [ True,  True,  True, False, False],\n",
       "        [ True,  True,  True, False, False]])"
      ]
     },
     "execution_count": 346,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "new_mask"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
